"""
disentanglement 从sequence的角度提供了分析方法。
beta的增大一方面减少了MI,另一方面增加了有序程度。
"""
#matplotlib.use('webagg')
import argparse
from collections import defaultdict

from torch import optim
from torch.utils.data import DataLoader
import wandb
from tqdm import trange

from disvae.models.anneal import get_anneal
from disvae.models.losses import get_loss_s, _get_log_pz_qz_prodzi_qzCx
from disvae.models.vae import VAE
from disvae.models.encoders import get_encoder
from disvae.models.decoders import get_decoder
from utils.visualize import *
from exps.patterns import data_generator
import numpy as np

from utils.helpers import set_seed

# Default hyperparameters. Every entry is exposed as a CLI flag of the same
# name, typed after its default, so e.g. `--beta 5` overrides the value below.
hyperparameter_defaults = {
    'learning_rate': 0.001,
    'epochs': 501,
    'dimension': 1,
    'loss': 'betaH',
    'beta': 10,
    'img_id': 0,
    'random_seed': 224,
}
parser = argparse.ArgumentParser()
for name, default in hyperparameter_defaults.items():
    parser.add_argument(f'--{name}', type=type(default), default=default)
args = parser.parse_args()
config = args

# Seed everything before any model/data construction for reproducibility.
set_seed(config.random_seed)
wandb.init(project="exps", config=config, group='sequence')


def get_preds(model, dl):
    """Encode every batch of `dl` and collect the posterior means.

    Temporarily puts `model` in eval mode and restores train mode before
    returning.  Returns `(preds, targets)` as two concatenated tensors:
    the encoder means and the corresponding labels.
    """
    model.eval()
    mus, labels = [], []
    # One no_grad context around the whole pass; the data itself carries
    # no gradients, so this is equivalent to a per-batch context.
    with torch.no_grad():
        for img, label in dl:
            batch = img.view(-1, 1, 64, 64).cuda()
            mu, _ = model.encoder(batch)
            mus.append(mu)
            labels.append(label)
    model.train()
    return torch.cat(mus), torch.cat(labels)


def train(dl, model, loss_f, epochs):
    """Fit `model` on `dl` for `epochs` epochs, logging metrics to wandb.

    `loss_f` is expected to append its per-batch metric values into the
    `storer` dict it is given.  Every 100 epochs the last batch's input and
    reconstruction are logged as images.  Returns the metric dict of the
    final epoch, with per-batch lists averaged to scalars.
    """
    optimizer = optim.AdamW(model.parameters(), config.learning_rate)
    for epoch in trange(epochs):
        storer = defaultdict(list)
        epoch_losses = []
        for img, label in dl:
            img = img.cuda()
            recon_batch, latent_dist, latent_sample = model(img)
            loss = loss_f(img, recon_batch, latent_dist, model.training,
                          storer, latent_sample=latent_sample)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_losses.append(loss.item())

        # Collapse the per-batch lists accumulated by loss_f into means.
        for key in storer:
            if isinstance(storer[key], list):
                storer[key] = np.mean(storer[key])
        storer['loss'] = np.mean(epoch_losses)

        if epoch % 100 == 0:
            # Visual sanity check on the most recent batch.
            wandb.log({'recon': wandb.Image(recon_batch[0, 0]),
                       'img': wandb.Image(img[0, 0])})
        storer['epoch'] = epoch
        wandb.log(storer, sync=False)
    return storer

def inversion(seq):
    """Count the inversions in `seq`.

    An inversion is a pair of positions i < j with seq[i] > seq[j]; the
    count is 0 for a sorted sequence and n*(n-1)/2 for a reverse-sorted
    one, so it measures how far `seq` is from being ordered.

    Handles empty and single-element sequences (both have 0 inversions).
    """
    from itertools import combinations
    # Counting with `if a > b` (instead of summing the comparison result)
    # keeps the return type a plain int even when the elements are tensor
    # scalars, matching the original hand-rolled double loop.
    return sum(1 for a, b in combinations(seq, 2) if a > b)

# --- data -----------------------------------------------------------------
img_id = config.img_id
epochs = config.epochs
dim = config.dimension
beta = config.beta

# Load one base pattern and wrap it in a dataset of scaled variants.
img = torch.load(f'patterns/{img_id}.pat')
dataset = data_generator.Scaling(img)
dl = DataLoader(dataset, batch_size=50, shuffle=True, num_workers=0)

# --- model ----------------------------------------------------------------
vae = VAE((1, 64, 64), get_encoder('Burgess'), get_decoder('Burgess'), dim)
vae.cuda()
vae.train()

# wandb.watch(vae, log_freq=10)
iterations = len(dl) * epochs
# Constant annealing schedule (weight fixed at 1 for all iterations).
anneal = get_anneal('constant', iterations, 1, 1)

loss_f = get_loss_s(config.loss, anneal, config.beta, len(dl.dataset))

storer = train(dl, vae, loss_f, epochs)

# Per-dimension KL terms from the final epoch (keys like 'kl_0', 'kl_1', ...).
kl_loss = torch.Tensor([v for k, v in storer.items() if 'kl_' == k[:3]])
# NOTE(review): `index` (latent dims sorted by descending KL) is never used
# below — this looks like dead code from an earlier analysis; consider removing.
_, index = kl_loss.sort(descending=True)

# Move to CPU eval mode for the analysis/plotting phase below.
vae.eval()
vae.cpu()

# Latent-traversal grid (7 steps, range ±2) over the first image for each latent.
fig = plt_sample_traversal(dataset.imgs[:1], vae, 7, range(dim), r=2)


# --- ordering (inversion count) -------------------------------------------
# Encode every image; mu holds the 1-D latent mean per image.
_, params, z = vae(dataset.imgs)
mu = params[0].data.flatten()
order = mu.argsort()
# Number of out-of-order pairs in the latent ordering; 0 means the latent
# is perfectly monotone over the dataset.
inv = inversion(order)

# --- mutual information ----------------------------------------------------
log_pz, log_qz, log_prod_qzi, log_q_zCx = _get_log_pz_qz_prodzi_qzCx(z,
                                                                     params,
                                                                     len(dataset.imgs))
# I[z;x] = KL[q(z,x)||q(x)q(z)] = E_x[KL[q(z|x)||q(z)]]
mi_loss = (log_q_zCx - log_qz).mean().data

# np.log(0) would emit a RuntimeWarning before yielding -inf; make the
# perfectly-ordered case (inv == 0) explicit with the same logged value.
inv_entropy = np.log(inv) if inv > 0 else -np.inf
wandb.log({'traversal': wandb.Image(fig),
           'inversion': inv,
           'inv_entropy': inv_entropy,
           'mi_entropy': mi_loss.item()})

fig = plt.figure()
plt.plot(mu)
plt.show()